{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer\n",
    "import transformers\n",
    "import torch\n",
    "import oss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.19s/it]\n"
     ]
    }
   ],
   "source": [
    "# Location of the LLaMA checkpoint; replace the placeholder before running.\n",
    "base_model = \"YOUR_LLAMA_MODEL_PATH\"\n",
    "\n",
    "# Tokenizer and model are both loaded from the same checkpoint location.\n",
    "tokenizer = LlamaTokenizer.from_pretrained(base_model)\n",
    "\n",
    "# Weights are requested in float16; device placement is left to device_map=\"auto\".\n",
    "model = LlamaForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16, device_map=\"auto\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Xbox 360 Guitar Hero 5 Guitar Bundle',\n",
       " 'Cake Mania 2 - Nintendo DS',\n",
       " 'Ultimate I Spy - Nintendo Wii',\n",
       " 'Call of Duty: Ghosts Prestige Edition - Xbox 360',\n",
       " 'Call of Duty World at War Collector&#39;s Edition - Xbox 360']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pin the special-token ids on the model config; id 0 (unk) doubles as the pad token.\n",
    "model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk\n",
    "model.config.bos_token_id = 1\n",
    "model.config.eos_token_id = 2\n",
    "model.eval()\n",
    "\n",
    "# Item name file: one item per line, tab-separated as `item_name<TAB>item_id`, e.g.\n",
    "#   A<TAB>0\n",
    "#   B<TAB>1\n",
    "# The `with` block closes the file handle even if reading raises.\n",
    "with open('YOUR ITEM NAME FILE', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "\n",
    "# Keep only the name column, then strip surrounding spaces and double quotes.\n",
    "# Make sure this preprocessing stays identical to the one used at prediction time.\n",
    "text = [line.split('\\t')[0].strip(\" \").strip('\\\"') for line in lines]\n",
    "\n",
    "# Pad on the left: the embedding cell reads the hidden state at position -1 of each sequence.\n",
    "tokenizer.padding_side = \"left\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1088it [05:35,  3.24it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "def batch(list, batch_size=1):\n",
    "    \"\"\"Yield successive slices of `list`, each spanning `batch_size` positions.\n",
    "\n",
    "    The final slice is shorter when the length is not a multiple of `batch_size`.\n",
    "    Note: the parameter name shadows the builtin `list` inside this function only.\n",
    "    \"\"\"\n",
    "    n_chunks = (len(list) - 1) // batch_size + 1\n",
    "    for start in (batch_size * k for k in range(n_chunks)):\n",
    "        stop = start + batch_size\n",
    "        yield list[start:stop]\n",
    "# Number of item names tokenized and pushed through the model per forward pass.\n",
    "embed_batch_size = 16\n",
    "# Same chunk count that `batch` produces; lets tqdm show a real progress bar.\n",
    "n_batches = (len(text) - 1) // embed_batch_size + 1\n",
    "\n",
    "item_embedding_chunks = []\n",
    "# Inference only: without no_grad every forward pass records an autograd graph and\n",
    "# keeps intermediate activations alive, which wastes memory for no benefit here.\n",
    "with torch.no_grad():\n",
    "    for batch_input in tqdm(batch(text, embed_batch_size), total=n_batches):\n",
    "        # `encoded` rather than `input`, so the builtin is not overwritten at notebook scope.\n",
    "        encoded = tokenizer(batch_input, return_tensors=\"pt\", padding=True)\n",
    "        outputs = model(\n",
    "            encoded.input_ids,\n",
    "            attention_mask=encoded.attention_mask,\n",
    "            output_hidden_states=True,\n",
    "        )\n",
    "        # Final entry of hidden_states, taken at the last position of every sequence.\n",
    "        # This relies on tokenizer.padding_side == \"left\" (set in an earlier cell),\n",
    "        # otherwise position -1 could be a padding token.\n",
    "        last_position_state = outputs.hidden_states[-1][:, -1, :]\n",
    "        item_embedding_chunks.append(last_position_state.detach().cpu())\n",
    "\n",
    "# Concatenate the per-batch results along dim 0 (one row per item name) and persist them.\n",
    "item_embedding = torch.cat(item_embedding_chunks, dim=0)\n",
    "torch.save(item_embedding, 'item_embedding.pt')\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([17408, 4096])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape check on the stacked embeddings; dim 0 is the concatenation axis (one row per item name).\n",
    "item_embedding.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "alpaca_lora",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
